"""
Train a classifier with clean ('cln') or PGD adversarial ('adv') training (AT).
"""
from __future__ import print_function

import os
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import get_model, get_root_path, get_data_loaders
from advertorch.context import ctx_noparamgrad_and_eval


def training_config(mode, dataset):
    """Return the training schedule for a given mode.

    Args:
        mode: ``"cln"`` for clean training or ``"adv"`` for PGD adversarial
            training.
        dataset: dataset name; only affects the epoch count in ``"adv"`` mode.

    Returns:
        ``(flag_advtrain, nb_epoch, model_filename)``.

    Raises:
        ValueError: if ``mode`` is not ``"cln"`` or ``"adv"``.
    """
    if mode == "cln":
        # Clean training uses the same epoch budget for every dataset.
        return False, 20, "mnist_lenet5_clntrained.pt"
    if mode == "adv":
        nb_epoch = 90 if dataset == 'mnist' else 150
        return True, nb_epoch, "mnist_lenet5_advtrained.pt"
    raise ValueError(f"Unknown mode {mode!r}; expected 'cln' or 'adv'")


def evaluate_model(model, loader, device, adversary=None):
    """Evaluate ``model`` on ``loader`` with clean and, optionally, attacked inputs.

    Args:
        model: classifier returning logits.
        loader: iterable of ``(data, target)`` batches exposing ``.dataset``.
        device: device the batches are moved to.
        adversary: optional object with ``perturb(data, target)``; when given,
            each batch is also evaluated on its perturbed version.

    Returns:
        ``(clean_loss, clean_correct, adv_loss, adv_correct)``. Losses are summed
        cross-entropy divided by ``len(loader.dataset)``. The adversarial
        entries are ``None`` when ``adversary`` is ``None``.
    """
    model.eval()
    n_samples = len(loader.dataset)
    clean_loss, clean_correct = 0.0, 0
    adv_loss, adv_correct = 0.0, 0

    for clndata, target in loader:
        clndata, target = clndata.to(device), target.to(device)
        with torch.no_grad():
            output = model(clndata)
        clean_loss += F.cross_entropy(output, target, reduction='sum').item()
        clean_correct += output.argmax(dim=1).eq(target).sum().item()

        if adversary is not None:
            # The attack needs input gradients, but parameter gradients must
            # not accumulate on the model weights during evaluation.
            with ctx_noparamgrad_and_eval(model):
                advdata = adversary.perturb(clndata, target)
            with torch.no_grad():
                output = model(advdata)
            adv_loss += F.cross_entropy(output, target, reduction='sum').item()
            adv_correct += output.argmax(dim=1).eq(target).sum().item()

    if adversary is None:
        return clean_loss / n_samples, clean_correct, None, None
    return clean_loss / n_samples, clean_correct, adv_loss / n_samples, adv_correct


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train MNIST')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--cluster', default="om", help="om | bt | beluga")
    parser.add_argument('--dataset', default="mnist", help="mnist | fashist")
    parser.add_argument('--mode', default="adv", help="cln | adv")
    parser.add_argument('--train_batch_size', default=128, type=int)
    parser.add_argument('--test_batch_size', default=1000, type=int)
    parser.add_argument('--log_interval', default=200, type=int)
    parser.add_argument('--enc_model', default='resnet10', type=str)
    parser.add_argument('--encoder_loadpath',
                        default=None,
                        type=str)
    args = parser.parse_args()

    exp_name = f'adv_training_{args.enc_model}'
    ROOT_PATH = get_root_path(args.cluster)
    DATA_PATH = os.path.join(ROOT_PATH, 'Data', args.dataset)
    TRAINED_MODEL_PATH = os.path.join(
        f'/braintree/data2/active/users/bashivan/results/advil/trained_models/{args.dataset}',
        exp_name)
    os.makedirs(TRAINED_MODEL_PATH, exist_ok=True)

    torch.manual_seed(args.seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    flag_advtrain, nb_epoch, model_filename = training_config(args.mode, args.dataset)

    if args.dataset not in ('mnist', 'fashist'):
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected 'mnist' or 'fashist'")
    train_loader, test_loader = get_data_loaders(
        args.dataset, args.train_batch_size, args.test_batch_size, DATA_PATH)

    # Build the classifier as encoder followed by decoder.
    # NOTE(review): 'mnist' is passed even for fashist; presumably fine since
    # both are 28x28 grayscale -- confirm against get_model.
    E, Dc = get_model('mnist', args.enc_model, num_decoder_features=200)
    model = nn.Sequential(E, Dc)

    if args.encoder_loadpath:
        print('==> Resuming from checkpoint..')
        print(f'from {args.encoder_loadpath}')
        # torch.load unpickles the file: only load trusted checkpoints.
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        chkpt = torch.load(args.encoder_loadpath, map_location=device)
        # Assumes the checkpoint is a pair of state dicts whose merged keys
        # match nn.Sequential(E, Dc) -- TODO confirm against the saving code.
        chkpt[0].update(chkpt[1])
        model.load_state_dict(chkpt[0])
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    Elr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [40, 80], gamma=0.1, last_epoch=-1)

    adversary = None
    if flag_advtrain:
        from advertorch.attacks import LinfPGDAttack
        adversary = LinfPGDAttack(
            model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.3,
            nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0.0,
            clip_max=1.0, targeted=False)

    n_test = len(test_loader.dataset)
    for epoch in range(nb_epoch):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            if flag_advtrain:
                # when performing attack, the model needs to be in eval mode
                # also the parameters should NOT be accumulating gradients
                with ctx_noparamgrad_and_eval(model):
                    data = adversary.perturb(data, target)

            optimizer.zero_grad()
            output = model(data)
            # 'mean' replaces the removed 'elementwise_mean' reduction name.
            loss = F.cross_entropy(output, target, reduction='mean')
            loss.backward()
            optimizer.step()
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx *
                    len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

        test_clnloss, clncorrect, test_advloss, advcorrect = evaluate_model(
            model, test_loader, device, adversary)

        print('\nTest set: avg cln loss: {:.4f},'
              ' cln acc: {}/{} ({:.0f}%)\n'.format(
                  test_clnloss, clncorrect, n_test,
                  100. * clncorrect / n_test))
        if flag_advtrain:
            print('Test set: avg adv loss: {:.4f},'
                  ' adv acc: {}/{} ({:.0f}%)\n'.format(
                      test_advloss, advcorrect, n_test,
                      100. * advcorrect / n_test))
        Elr_scheduler.step()
        torch.save(
            model.state_dict(),
            os.path.join(TRAINED_MODEL_PATH, model_filename))